
Conversation

@shewu-quic
Collaborator

Background

We observed that quantizing and compiling the original SHA (single-head attention) model takes a significant amount of time, while switching to the MHA (multi-head attention) model speeds up this process. We therefore investigated whether the MHA model could be converted to SHA after quantization. This conversion cannot be performed during the to_edge transformation, because splitting the convolution weights into per-head SHA weights would require modifying the state_dict, which is not permitted at that stage. We therefore apply the pass during qnn_preprocess.
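For intuition, here is a minimal sketch of the splitting idea with made-up shapes; the actual pass rewrites convolution nodes in the edge graph during qnn_preprocess, not raw tensors:

import torch

# Hypothetical shapes for illustration only.
n_heads, head_dim, embed_dim = 4, 8, 32

# Fused MHA Q-projection as a 1x1 convolution weight:
# (out_channels, in_channels, kH, kW) = (n_heads * head_dim, embed_dim, 1, 1).
wq_mha = torch.randn(n_heads * head_dim, embed_dim, 1, 1)

# SHA form: split along the output-channel dimension, one weight per head.
wq_sha = torch.split(wq_mha, head_dim, dim=0)

assert len(wq_sha) == n_heads
assert wq_sha[0].shape == (head_dim, embed_dim, 1, 1)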

Summary:

  • Integrated the MHA-to-SHA pass into qnn_preprocess
  • Refactored MHA in static llama
    • Included SpinQuant R3 support and masked softmax for the MHA model in static llama
    • Combined the n_heads key-value caches into a single cache per layer to reduce the number of inputs and outputs, which improves performance (see the sketch after this list)
  • Deprecated the ShiftPointer KV updater mode
    • Since each layer now has its own KV cache, the v cache no longer benefits from ShiftPointer, which previously avoided copying the new v cache into the input v cache. To prevent user confusion, ShiftPointer mode has been deprecated.
  • Applied the correct input template for SmolLM2 135M
  • Corrected the quantization annotation for reshape
  • Removed outdated code from CanonicalizeConv
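A minimal sketch of the combined per-layer KV cache mentioned above, with hypothetical shapes (not the actual static llama code):

import torch

n_heads, head_dim, max_seq_len = 4, 8, 128

# Before: one cache tensor per head per layer, i.e. 2 * n_heads cache I/Os per layer.
per_head_k_caches = [torch.zeros(1, head_dim, max_seq_len) for _ in range(n_heads)]

# After: a single cache tensor per layer, i.e. 2 cache I/Os per layer (K and V).
k_cache = torch.zeros(1, n_heads, head_dim, max_seq_len)

# Same data; each head is a slice along dim 1 of the combined tensor.
assert k_cache[:, 0].shape == per_head_k_caches[0].shape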

Results

Following the README settings, tested on SM8750 with QNN 2.37. Compared the new convert_mha_to_sha pass against the original SHA structure:

[image: performance comparison of convert_mha_to_sha vs. the original SHA structure]

@shewu-quic shewu-quic requested a review from cccclai as a code owner October 29, 2025 06:56
@pytorch-bot

pytorch-bot bot commented Oct 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15438

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (4 Unrelated Failures)

As of commit 601c14c with merge base ca4c575:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Oct 29, 2025
@shewu-quic
Collaborator Author

@pytorchbot label "release notes: qualcomm"

@pytorch-bot pytorch-bot bot added the release notes: qualcomm label Oct 29, 2025
@shewu-quic
Collaborator Author

Hi @cccclai,
This PR migrates the mha2sha transformation from the source level to a pass applied in qnn_preprocess. It significantly improves lowering time, including quantization and compilation time.
Could you please take a look?

Thanks

@cccclai
Contributor

cccclai commented Oct 31, 2025

Hi, since it's a really big change and the MHA2SHA pass seems complicated, can you add a test for the pass here: https://github.com/pytorch/executorch/blob/main/backends/qualcomm/tests/test_passes.py? Passes can be fragile, so I'm trying to make sure we have them covered by tests.

@shewu-quic shewu-quic force-pushed the dev1/hutton/add_mha_to_sha_pass branch from 18e7db1 to 0a666d2 Compare November 3, 2025 10:00
@shewu-quic
Collaborator Author


Thanks for pointing it out. I have added a test case to check the functionality of MHA2SHA.

conv_nodes = [
    n
    # graph_module: the transformed edge graph under test
    for n in graph_module.graph.nodes
    if n.target == exir_ops.edge.aten.convolution.default
]
# Check graph structure: WQ, WK, WV should be converted to SHA
self.assertTrue(len(conv_nodes) == 25, "Convolution nodes should be split")
Contributor

Thanks for adding the test! Is it possible to check whether the numerics are the same?

Collaborator Author

Sure. I have added it. Thanks!
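For reference, a minimal sketch of what such a numeric check can look like; MhaModel and ConvertMhaToSha are hypothetical stand-ins, not the actual names in test_passes.py:

import torch
from executorch.exir import to_edge

# Hypothetical stand-ins for the model and the pass under test.
model = MhaModel().eval()
sample_input = (torch.randn(1, 32, 1, 8),)

# Reference output from the original MHA model.
golden = model(*sample_input)

# Export, apply the MHA-to-SHA pass, and run the transformed graph
# (assumes parameters remain reachable from the graph module).
edge = to_edge(torch.export.export(model, sample_input))
result = ConvertMhaToSha(edge.exported_program())(edge.exported_program().graph_module)
actual = result.graph_module(*sample_input)

# Splitting MHA into per-head SHA convolutions should be numerically
# equivalent up to floating-point tolerance.
torch.testing.assert_close(actual, golden, atol=1e-5, rtol=1e-5)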

@shewu-quic shewu-quic force-pushed the dev1/hutton/add_mha_to_sha_pass branch from 0a666d2 to 0b33455 Compare November 4, 2025 05:13
